﻿#include <cstring>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

#include <jvmti.h>

#include "cpu_id.h"
#include "rsa/rsa.h"
#include "base64/base64.h"
#include "org_magicwall_universal_core_BytecodeDecrypt.h"
#include "com_fenquen_sourceguard_dencrypt_Dencrypt.h"

// True when the first four bytes of byte_code carry the marker written by CHANGE_MAGIC:
// 01 00 01 00 (plain) or 01 dd 01 00 (bound to a machine). The caller must guarantee that at
// least four bytes are readable. Elements are compared as unsigned bytes, so both unsigned char
// and signed (jbyte) buffers work. The whole expansion is parenthesized so it composes safely
// with other operators, e.g. !IS_ENCRYPTED(x).
#define IS_ENCRYPTED(byte_code)                                  \
    (static_cast<unsigned char>((byte_code)[0]) == 0x01 &&       \
     (static_cast<unsigned char>((byte_code)[1]) == 0x00 ||      \
      static_cast<unsigned char>((byte_code)[1]) == 0xdd) &&     \
     static_cast<unsigned char>((byte_code)[2]) == 0x01 &&       \
     static_cast<unsigned char>((byte_code)[3]) == 0x00)

// True when an encrypted image is bound to a machine (second marker byte 0xdd), i.e. it can only
// be decrypted with a valid activation code.
#define NEED_ACTIVATION_CODE(byte_code) \
    (static_cast<unsigned char>((byte_code)[1]) == 0xdd)

// Writes the standard class-file magic 0xCAFEBABE over the first four bytes of origin_byte_code.
// Wrapped in do { } while (0) so it acts as a single statement, e.g. inside an unbraced if/else.
#define RESTORE_MAGIC(origin_byte_code) \
    do {                                \
        (origin_byte_code)[0] = 0xca;   \
        (origin_byte_code)[1] = 0xfe;   \
        (origin_byte_code)[2] = 0xba;   \
        (origin_byte_code)[3] = 0xbe;   \
    } while (0)

// Index of the first byte after the 4-byte magic/marker; the bytes from here on are nibble-swapped.
#define START_INDEX 4

// Overwrites the first four bytes of origin_byte_code with the encryption marker:
// 01 dd 01 00 when bind_machine is true (activation code required), 01 00 01 00 otherwise.
// Wrapped in do { } while (0) so it acts as a single statement.
#define CHANGE_MAGIC(origin_byte_code, bind_machine)              \
    do {                                                          \
        (origin_byte_code)[0] = 0x01;                             \
        (origin_byte_code)[1] = (bind_machine) ? 0xdd : 0x00;     \
        (origin_byte_code)[2] = 0x01;                             \
        (origin_byte_code)[3] = 0x00;                             \
    } while (0)

// Identifier of the current machine; Agent_OnLoad fills it from get_mac_and_cpu_id().
std::string machine_id{};
// Value of the "activation_code" system property (empty when unset); filled by Agent_OnLoad.
std::string activation_code{};

void a(const unsigned char *old_data, int old_data_len, unsigned char *new_data) {
    // 说明是加密过的
    if (IS_ENCRYPTED(old_data)) {
        if (NEED_ACTIVATION_CODE(old_data)) {
            if (activation_code.length() == 0) {
                printf("需要激活码");
                exit(0);
            }

            char machine_id0[100] = {0};


            unsigned char *rsaEncrypted = base64_decode(activation_code.c_str());

            public_decrypt(rsaEncrypted, sizeof(rsaEncrypted),
                           reinterpret_cast<unsigned char *>(machine_id0));

            if (machine_id != machine_id0) {
                printf("激活码不正确");
                exit(0);
            }

            free(rsaEncrypted);
        }

        RESTORE_MAGIC(new_data);

        for (int i = 4; i < old_data_len; ++i) {
            unsigned char low = old_data[i] & 0x0f;
            unsigned char high = (old_data[i] >> 4) & 0x0f;

            new_data[i] = (low << 4) | high;
        }
    } else { // 如果不是加密过的那么原样呈现
        for (int i = 0; i < old_data_len; ++i) {
            new_data[i] = old_data[i];
        }
    }
}

// Native implementation of com.fenquen.sourceguard.dencrypt.Dencrypt.decrypt0(byte[]).
// It shares its implementation with BytecodeDecrypt.decrypt0 and simply delegates to it,
// returning whatever that function returns.
JNIEXPORT jbyteArray JNICALL Java_com_fenquen_sourceguard_dencrypt_Dencrypt_decrypt0
        (JNIEnv *jniEnv, jclass jclazz, jbyteArray inJNIArray) {
    jbyteArray decrypted =
            Java_org_magicwall_universal_1core_BytecodeDecrypt_decrypt0(jniEnv, jclazz, inJNIArray);
    return decrypted;
}

// Native implementation of Dencrypt.encrypt0(byte[], boolean): returns an "encrypted" copy of a
// class-file image.
//
// The first four bytes are overwritten with the marker written by CHANGE_MAGIC:
// 01 dd 01 00 when bind_machine is true, 01 00 01 00 otherwise. The high/low nibbles of every
// following byte are then swapped. Inputs shorter than the 4-byte marker are returned unchanged,
// since there is nothing to overwrite. The input array itself is never modified.
//
// Returns a new byte[] of the same length. Returns null if inJNIArray is null, or if the result
// array cannot be allocated (NewByteArray then leaves an OutOfMemoryError pending).
JNIEXPORT jbyteArray JNICALL Java_com_fenquen_sourceguard_dencrypt_Dencrypt_encrypt0
        (JNIEnv *jniEnv, jclass jclazz, jbyteArray inJNIArray, jboolean bind_machine) {
    if (inJNIArray == nullptr) {
        return nullptr;
    }

    const jsize arrLen = jniEnv->GetArrayLength(inJNIArray);

    // Copy the input into a native buffer and transform it there. GetByteArrayRegion needs no
    // matching release call, and the buffer is freed automatically on every return path.
    std::vector<jbyte> outCArray(static_cast<size_t>(arrLen));
    if (arrLen > 0) {
        jniEnv->GetByteArrayRegion(inJNIArray, 0, arrLen, outCArray.data());
    }

    // Writing the 4-byte marker into a shorter buffer would overflow it.
    if (arrLen >= START_INDEX) {
        CHANGE_MAGIC(outCArray, bind_machine);

        for (jsize i = START_INDEX; i < arrLen; ++i) {
            const unsigned char byte = static_cast<unsigned char>(outCArray[i]);
            outCArray[i] = static_cast<jbyte>(((byte & 0x0f) << 4) | (byte >> 4));
        }
    }

    jbyteArray outJNIArray = jniEnv->NewByteArray(arrLen);
    if (outJNIArray == nullptr) {
        return nullptr;
    }
    if (arrLen > 0) {
        jniEnv->SetByteArrayRegion(outJNIArray, 0, arrLen, outCArray.data());
    }

    return outJNIArray;
}

// Native implementation of BytecodeDecrypt.decrypt0(byte[]): returns the decrypted form of a
// class-file image produced by encrypt0, or an unchanged copy if it carries no encryption marker.
// The decryption itself, including the activation-code check for machine-bound images, is done
// by a().
//
// Returns a new byte[] of the same length. Returns null if inJNIArray is null, if the input
// elements cannot be obtained, or if the result array cannot be allocated.
JNIEXPORT jbyteArray JNICALL Java_org_magicwall_universal_1core_BytecodeDecrypt_decrypt0
        (JNIEnv *jniEnv, jclass jclazz, jbyteArray inJNIArray) {
    if (inJNIArray == nullptr) {
        return nullptr;
    }

    const jsize old_data_len = jniEnv->GetArrayLength(inJNIArray);

    jbyte *inCArray = jniEnv->GetByteArrayElements(inJNIArray, nullptr);
    if (inCArray == nullptr) {
        return nullptr;
    }

    std::vector<jbyte> outCArray(static_cast<size_t>(old_data_len));
    a(reinterpret_cast<const unsigned char *>(inCArray), old_data_len,
      reinterpret_cast<unsigned char *>(outCArray.data()));

    // The input was only read: JNI_ABORT releases any copy without writing it back.
    jniEnv->ReleaseByteArrayElements(inJNIArray, inCArray, JNI_ABORT);

    jbyteArray outJNIArray = jniEnv->NewByteArray(old_data_len);
    if (outJNIArray == nullptr) {
        return nullptr;
    }
    if (old_data_len > 0) {
        jniEnv->SetByteArrayRegion(outJNIArray, 0, old_data_len, outCArray.data());
    }

    return outJNIArray;
}

// JVMTI ClassFileLoadHook callback: decrypts encrypted class files as the JVM loads them.
//
// Class files that do not start with the encryption marker are left alone. *new_class_data is not
// touched, so the JVM keeps using class_data as-is. For encrypted ones, a buffer is obtained with
// the JVMTI allocator (the VM later frees it with Deallocate) and filled by a(). It is then handed
// back through new_class_data / new_class_data_len.
void JNICALL MyClassFileLoadHook(jvmtiEnv *jvmti_env,
                                 JNIEnv *jni_env,
                                 jclass class_being_redefined,
                                 jobject loader,
                                 const char *name,
                                 jobject protection_domain,
                                 jint class_data_len,
                                 const unsigned char *class_data,
                                 jint *new_class_data_len,
                                 unsigned char **new_class_data) {
    // Leaving new_class_data unset means "unmodified". This avoids allocating and copying every
    // plain class the VM loads.
    if (class_data == nullptr || class_data_len < START_INDEX || !(IS_ENCRYPTED(class_data))) {
        return;
    }

    unsigned char *decrypted = nullptr;
    if (jvmti_env->Allocate(class_data_len, &decrypted) != JVMTI_ERROR_NONE || decrypted == nullptr) {
        fprintf(stderr, "ERROR: Unable to Allocate memory to decrypt class %s\n",
                name != nullptr ? name : "<unnamed>");
        return;
    }

    a(class_data, class_data_len, decrypted);

    *new_class_data = decrypted;
    *new_class_data_len = class_data_len;
}

// Calls exception.printStackTrace() through JNI, roughly like invoking it via reflection.
// This is a best-effort debugging helper. It does nothing for a null exception. If
// java.lang.Throwable or its printStackTrace() method cannot be resolved, the pending JNI
// exception is cleared and the function returns without printing.
void printStackTrace(JNIEnv *jniEnv, jobject exception) {
    if (exception == nullptr) {
        return;
    }

    jclass throwable_class = jniEnv->FindClass("java/lang/Throwable");
    if (throwable_class == nullptr) {
        jniEnv->ExceptionClear();
        return;
    }

    jmethodID print_method = jniEnv->GetMethodID(throwable_class, "printStackTrace", "()V");
    // jmethodIDs stay valid after the local class reference is deleted.
    jniEnv->DeleteLocalRef(throwable_class);
    if (print_method == nullptr) {
        jniEnv->ExceptionClear();
        return;
    }

    jniEnv->CallVoidMethod(exception, print_method);
}

// JVMTI Exception event callback, registered and enabled by Agent_OnLoad.
// It is intentionally a no-op. For debugging, the thrown object can be inspected here, e.g.
// printStackTrace(jni_env, exceptionInstance) prints its stack trace.
void JNICALL Callback_JVMTI_EVENT_EXCEPTION(jvmtiEnv *jvmti_env,
                                            JNIEnv *jni_env,
                                            jthread thread,
                                            jmethodID method,
                                            jlocation location,
                                            jobject exceptionInstance,
                                            jmethodID catch_method,
                                            jlocation catch_location) {
    // Every parameter is deliberately unused.
    (void) jvmti_env;
    (void) jni_env;
    (void) thread;
    (void) method;
    (void) location;
    (void) exceptionInstance;
    (void) catch_method;
    (void) catch_location;
}

// Agent entry point, called by the VM when the library is loaded at startup.
// `options` holds the text after '=' in the agent option, e.g. "aaaa" for
// -agentlib:sourceguard.so=aaaa; it is currently unused.
//
// The function does the following:
//  - obtains a jvmtiEnv;
//  - stores the optional "activation_code" system property in the global activation_code;
//  - stores this machine's id in the global machine_id;
//  - registers and enables the ClassFileLoadHook (class decryption) and Exception callbacks.
//
// Returns JNI_OK on success. Any other return value makes the VM abort startup.
JNIEXPORT jint JNICALL Agent_OnLoad(JavaVM *javaVM, char *options, void *reserved) {
    jvmtiEnv *pJvmtiEnv = nullptr;
    const jint resultCode = javaVM->GetEnv(reinterpret_cast<void **>(&pJvmtiEnv), JVMTI_VERSION);
    if (resultCode != JNI_OK) {
        printf("%s\n", "couldn't get jvmti environment");
        return JNI_ERR;
    }

    // GetSystemProperty returns a string allocated with the JVMTI allocator, which must be
    // released with Deallocate. When the property is not set the call fails: the pointer stays
    // null and activation_code stays empty.
    char *activation_code_ = nullptr;
    if (pJvmtiEnv->GetSystemProperty("activation_code", &activation_code_) == JVMTI_ERROR_NONE &&
        activation_code_ != nullptr) {
        activation_code = std::string(activation_code_);
        pJvmtiEnv->Deallocate(reinterpret_cast<unsigned char *>(activation_code_));
    }
    printf("activation_code:%s\n", activation_code.c_str());

    machine_id = get_mac_and_cpu_id();
    if (machine_id == (INVALID_MAC_AND_CPU_ID)) {
        printf("不能得到机器码machine_id,系统终止");
        // An error return from Agent_OnLoad makes the VM abort startup with a failure status.
        return JNI_ERR;
    }

    printf("机器码machine_id:%s\n", machine_id.c_str());

    // Request capabilities. Value-initialisation zeroes every flag before the wanted ones are set.
    // NOTE(review): only the exception-event and class-hook capabilities relate to the callbacks
    // registered below; the others are not used anywhere in this file and could be dropped.
    jvmtiCapabilities jvmtiCapabilities_{};
    jvmtiCapabilities_.can_generate_exception_events = JVMTI_ENABLE;
    jvmtiCapabilities_.can_generate_all_class_hook_events = JVMTI_ENABLE;
    jvmtiCapabilities_.can_tag_objects = JVMTI_ENABLE;
    jvmtiCapabilities_.can_generate_object_free_events = JVMTI_ENABLE;
    jvmtiCapabilities_.can_get_source_file_name = JVMTI_ENABLE;
    jvmtiCapabilities_.can_get_line_numbers = JVMTI_ENABLE;
    jvmtiCapabilities_.can_generate_vm_object_alloc_events = JVMTI_ENABLE;

    jvmtiError error = pJvmtiEnv->AddCapabilities(&jvmtiCapabilities_);
    if (error != JVMTI_ERROR_NONE) {
        fprintf(stderr, "ERROR: Unable to AddCapabilities JVMTI");
        return JNI_ERR;
    }

    // Register callbacks. The struct must be zeroed: every slot that is not set must be null
    // rather than uninitialized stack contents.
    jvmtiEventCallbacks jvmtiEventCallbacks_{};
    jvmtiEventCallbacks_.Exception = &Callback_JVMTI_EVENT_EXCEPTION;
    jvmtiEventCallbacks_.ClassFileLoadHook = &MyClassFileLoadHook;

    error = pJvmtiEnv->SetEventCallbacks(&jvmtiEventCallbacks_,
                                         static_cast<jint>(sizeof(jvmtiEventCallbacks_)));
    if (error != JVMTI_ERROR_NONE) {
        fprintf(stderr, "ERROR: Unable to SetEventCallbacks JVMTI!");
        return JNI_ERR;
    }

    // Registering a callback is not enough: each event must also be enabled explicitly.
    // NOTE(review): Callback_JVMTI_EVENT_EXCEPTION currently does nothing, yet enabling this
    // event makes the VM post it for every thrown exception. Consider not enabling it.
    error = pJvmtiEnv->SetEventNotificationMode(JVMTI_ENABLE, JVMTI_EVENT_EXCEPTION, nullptr);
    if (error != JVMTI_ERROR_NONE) {
        printf("ERROR: Unable to SetEventNotificationMode JVMTI_EVENT_EXCEPTION,the error code=%d\n", error);
        return JNI_ERR;
    }

    error = pJvmtiEnv->SetEventNotificationMode(JVMTI_ENABLE, JVMTI_EVENT_CLASS_FILE_LOAD_HOOK, nullptr);
    if (error != JVMTI_ERROR_NONE) {
        printf("ERROR: Unable to SetEventNotificationMode JVMTI_EVENT_CLASS_FILE_LOAD_HOOK,the error code=%d\n", error);
        return JNI_ERR;
    }

    return JNI_OK;
}

// Entry point used when the agent is attached to an already running VM.
// It only prints its own name and reports success. Unlike Agent_OnLoad, it installs no callbacks.
JNIEXPORT jint JNICALL Agent_OnAttach(JavaVM *vm, char *options, void *reserved) {
    const char *const entry_point = "Agent_OnAttach";
    printf("%s\n", entry_point);
    return JNI_OK;
}

// Called by the VM when the agent library is unloaded; it only prints its own name.
JNIEXPORT void JNICALL Agent_OnUnload(JavaVM *vm) {
    const char *const entry_point = "Agent_OnUnload";
    printf("%s\n", entry_point);
}